import numpy as np
from scipy.linalg import block_diag
from scipy import optimize
from os import path, mkdir
from numpy.fft import fftn, ifftn, ifftshift
import matplotlib.pyplot as plt
import time

# plt.rcParams['font.family']= 'Times New Roman'
# plt.rcParams['mathtext.it']= 'Times New Roman'
# plt.rcParams['mathtext.rm']= 'Times New Roman'
plt.rcParams['font.size'] = 20

class Band_Hamiltonian(): #spin_valley projected
    """Effective 2x2 band Hamiltonian for one valley at a single momentum.

    The matrix is assembled from the hopping parameters in ``SWM`` (keys
    ``gamma_0`` .. ``gamma_4`` and ``delta``), the layer number ``N_lay``,
    the momentum ``(kx, ky)``, the valley index ``q`` and the potential
    difference ``U``.  ``kx`` and ``ky`` must be scalars because the result
    is written element by element into a 2x2 array.
    """

    def __init__(self,SWM,N_lay,kx,ky,U,q, H_renorm, **kwargs): #q=+1(-1) for K(K') valley
        """Store the model parameters; extra keyword arguments are ignored."""
        self.SWM = SWM
        self.N_lay = N_lay
        self.kx, self.ky = kx, ky
        self.U = U
        self.q = q
        self.H_renorm = H_renorm

    def Pi(self): #pi operator q*kx+i*ky
        """Return the complex momentum ``q*kx + i*ky``."""
        real_part = self.q*(self.kx)
        imag_part = 1j*(self.ky)
        return real_part+imag_part

    def v(self,gamma):#in unit of a/h_bar (a=2.46 A;lattice constant)
        """Convert a hopping energy ``gamma`` into ``sqrt(3)*gamma/2``."""
        return np.sqrt(3)*gamma/2

    def Hamiltonian(self):
        """Build and return the 2x2 complex Hamiltonian.

        The contributions are added in the order H_ch (off-diagonal,
        power N), H_s (diagonal), H_tr (off-diagonal, power N-3) and finally
        the potential term.  When ``H_renorm == True`` the potential term is
        weighted by ``r_gap`` and the whole matrix is divided by ``r_N``;
        otherwise ``+U/2`` and ``-U/2`` are added to the two diagonal entries.
        """
        gm1 = self.SWM['gamma_1']
        gm2 = self.SWM['gamma_2']
        hv0a = self.v(self.SWM['gamma_0'])
        hv3a = self.v(self.SWM['gamma_3'])
        hv4a = self.v(self.SWM['gamma_4'])
        N = self.N_lay

        # Pi() is a pure function of the stored parameters: evaluate it once.
        pi = self.Pi()
        pi_dag = np.conj(pi)
        k_sq = np.abs(pi)**2
        ratio = np.abs(hv0a*pi_dag/gm1)

        # Weighted sum over i = 0 .. N-2 (empty, hence 0, when N <= 1).
        r_gap = sum((1-2*i/(N-1))*(ratio**(2*i)) for i in range(0, N-1))

        # Closed-form geometric sums in ratio**2.
        r_N = (ratio**(2*N)-1)/(ratio**2-1)
        r_Nm1 = (ratio**(2*N-2)-1)/(ratio**2-1)

        H = np.zeros([2,2],dtype=complex)

        # H_ch
        H[0,1] += (-gm1)*(hv0a*pi_dag/(-gm1))**N
        H[1,0] += (-gm1)*(hv0a*pi/(-gm1))**N

        # H_s
        diagonal_shift = self.SWM['delta'] - 2*k_sq*hv0a*hv4a/gm1 * r_Nm1
        H[0,0] += diagonal_shift
        H[1,1] += diagonal_shift

        # H_tr
        tr_coeff = (N-2)*gm2/2 - (N-1)*hv0a*hv3a*k_sq/gm1
        H[0,1] += tr_coeff * (hv0a*pi_dag/(-gm1))**(N-3)
        H[1,0] += tr_coeff * (hv0a*pi/(-gm1))**(N-3)

        if self.H_renorm == True:
            # H_gap
            H[0,0] += self.U/2 * r_gap
            H[1,1] -= self.U/2 * r_gap
            return H/r_N # Final Hamiltonian

        H[0,0] += self.U/2
        H[1,1] -= self.U/2
        return H # Final Hamiltonian
    
#########################################################################################################

# Hamiltonians and Eigen States (Non-int & Seed state)
class Hamiltonian_solver_sweep(Band_Hamiltonian):
    """Assemble the 8-band Hamiltonian on a list of k-points and diagonalise it.

    Constructing an instance also publishes the 8x8 operator matrices
    (``eye``, ``S_x``, ``S_z``, ``V_x``, ``V_z``, ``L_x``, ``L_z``) and
    ``N_band`` as module-level globals, which the other methods read.
    """

    def __init__(self,SWM,kmax,dk,N_lay,U,seed_val, H_renorm, **kwargs):
        """Store the parameters and (re)define the module-level operators.

        ``nk`` is the number of grid points per direction, ``int(2*kmax/dk)+1``.
        Extra keyword arguments are ignored.
        """
        self.N_lay = N_lay
        self.U = U
        self.seed_val = seed_val
        self.SWM = SWM
        self.nk = int(2*kmax/dk)+1
        self.H_renorm = H_renorm

        # Declare the operator matrices as global
        global V_x, V_z, S_x, S_z, L_x, L_z, eye, N_band

        # 8 band structure = spin(up,dn) x valley(K,K') x orbital(1A,3B)
        # (order from the largest block to the smallest)
        N_band = 2*2*2

        # 2x2 building blocks
        sig_0 = np.array([[1,0],[0,1]])
        sig_x = np.array([[0,1],[1,0]])
        sig_z = np.array([[1,0],[0,-1]])

        def embed(spin, valley, orbital):
            # Kronecker product in the order spin (x) valley (x) orbital.
            return np.kron(spin, np.kron(valley, orbital))

        eye = embed(sig_0, sig_0, sig_0)
        S_x = embed(sig_x, sig_0, sig_0)
        S_z = embed(sig_z, sig_0, sig_0)
        V_x = embed(sig_0, sig_x, sig_0)
        V_z = embed(sig_0, sig_z, sig_0)
        L_x = embed(sig_0, sig_0, sig_x)
        L_z = embed(sig_0, sig_0, sig_z)

    def _valley_block(self, kx, ky, q):
        """Return the 2x2 Band_Hamiltonian matrix for valley index ``q``."""
        return Band_Hamiltonian(self.SWM, self.N_lay, kx, ky, self.U, q, self.H_renorm).Hamiltonian()

    def _scan_k(self, builder, kx_, ky_, SOC, SOC_dir):
        """Evaluate ``builder(kx, ky, SOC, SOC_dir)`` on every row of (kx_, ky_)."""
        k_points = np.concatenate((kx_, ky_), axis=1)

        def at_k(k):
            return builder(k[0], k[1], SOC, SOC_dir)

        return np.apply_along_axis(at_k, 1, k_points) #shape=(Nk,N_band,N_band)

    def Block_Hamiltonian(self,kx,ky,SOC,SOC_dir):
        """Return the spin x valley x orbital Hamiltonian at one k-point.

        The K and K' blocks are stored on the instance as ``mlg_K1`` and
        ``mlg_K2``.  The spin-orbit term is
        ``SOC * V_z @ (SOC_dir*S_x + (1-SOC_dir)*S_z) @ L_z``.
        """
        self.mlg_K1 = self._valley_block(kx, ky, +1) #K  valley
        self.mlg_K2 = self._valley_block(kx, ky, -1) #K' valley
        valley_part = block_diag(self.mlg_K1, self.mlg_K2)
        H_valley_spin = np.kron(np.identity(2), valley_part)
        soc_term = SOC*V_z@(SOC_dir * S_x + (1-SOC_dir) * S_z)@L_z #tuning spin-orbit coupling
        return H_valley_spin+soc_term

    def seed_Hamiltonian(self,kx,ky,SOC,SOC_dir):#(K_up,K'_up,K_down,K'_down)
        """Return Block_Hamiltonian plus the symmetry-breaking seed terms.

        ``self.seed_val`` supplies one weight per spin/valley flavour
        (``K_up``, ``K_dn``, ``Kp_up``, ``Kp_dn``) and the weights ``vx``,
        ``sx`` and ``vsx`` of ``V_x``, ``S_x`` and ``V_x @ S_x``.
        """
        seed = self.seed_val
        H_valley_spin = self.Block_Hamiltonian(kx,ky,SOC,SOC_dir)
        H_valley_spin += (seed["K_up"]  * (eye+V_z)@(eye+S_z)/4
                          + seed["K_dn"]  * (eye+V_z)@(eye-S_z)/4
                          + seed["Kp_up"] * (eye-V_z)@(eye+S_z)/4
                          + seed["Kp_dn"] * (eye-V_z)@(eye-S_z)/4
                          + seed["vx"]  * V_x
                          + seed["sx"]  * S_x
                          + seed["vsx"] * V_x@S_x)
        return H_valley_spin

    def nonint_H_string(self,kx_,ky_,SOC,SOC_dir): #array of Block Hamiltonian at each k-pt.
        """Return the stack of Block_Hamiltonian matrices, one per k-point."""
        return self._scan_k(self.Block_Hamiltonian, kx_, ky_, SOC, SOC_dir)

    def Eigen(self,nk,kx_,ky_,SOC,SOC_dir,y): #Eigenspectrum
        """Diagonalise the Hamiltonian on the k-point list.

        ``y == 0`` uses Block_Hamiltonian, ``y == 1`` uses seed_Hamiltonian;
        any other value returns None.  The eigenvalues are reshaped to
        ``(nk, nk, n)`` and the eigenvectors to ``(nk, nk, n, n)``, where
        ``n`` is the matrix size of Block_Hamiltonian.
        """
        n_band = len(self.Block_Hamiltonian(0,0,SOC,SOC_dir))
        if y==0: #non-interacting energy and eigenstates
            builder = self.Block_Hamiltonian
        elif y==1: #Seed energy and eigenstates
            builder = self.seed_Hamiltonian
        else:
            return None
        H = self._scan_k(builder, kx_, ky_, SOC, SOC_dir)
        eigen_values,eigen_vectors = np.linalg.eigh(H)
        return eigen_values.reshape(nk,nk,n_band),eigen_vectors.reshape(nk,nk,n_band,n_band)
        
#########################################################################################################

# Self-consistent Hartree-Fock solver (Fermi-Dirac occupation, chemical potential, Fock matrix, mean-field iteration)
class SCF(Hamiltonian_solver_sweep):
    def __init__(self,U,consts,**kwargs):
        self.ne_val=consts["ne_val"]
        self.U=U

        self.N_lay=consts["N_lay"]
        self.SWM = consts["SWM"]
        self.rmt_SWM = consts["rmt_SWM"]
        
        self.SWM["gamma_2"] = self.SWM["gamma_2"] * self.rmt_SWM
        self.SWM["gamma_3"] = self.SWM["gamma_3"] * self.rmt_SWM
        self.SWM["gamma_4"] = self.SWM["gamma_4"] * self.rmt_SWM
        self.SWM["delta"]   = self.SWM["delta"]   * self.rmt_SWM

        self.dk = consts["dk"]
        self.kmax = consts["kmax"]
        a=2.46e-8
        self.A=(2*np.pi*a/self.dk)**2 #area in cm^2
        self.nk=int(2*self.kmax/self.dk)+1
        k0=np.linspace(-self.kmax,self.kmax,self.nk)
        self.kx,self.ky=np.meshgrid(k0,k0)
        k_diff=np.sqrt(self.kx**2+self.ky**2)/a #in cm^-1; will be used in Fock matrix
        assert k_diff.shape==(self.nk,self.nk)        

        self.V_type = consts["V_type"]
        self.V0     = consts["V0"]
        self.ke = consts["ke"]
        self.er=consts["er"]
        self.d_gate=consts["d_gate"]
        self.alp = consts["alp"]
        self.Hund = consts["Hund"]
        self.beta = consts["beta"]

        self.SOC=consts["SOC"]
        self.SOC_dir=consts["SOC_dir"]        

        self.max_count=consts["max_count"]
        self.tol=consts["tol"]
        self.mix=consts["mix"]
        self.seed_val=consts["seed_val"]        
        
        self.run_idx = consts["run_idx"]

        self.H_renorm = consts["H_renorm"]
        self.fix_sym = consts["fix_sym"]
        self.calc_type = consts["calc_type"]
        
        self.V_q=2*np.pi*self.ke*np.tanh((k_diff+1e-15)*self.d_gate)/(self.er*(k_diff+1e-15))
        # Long-range Coloumb potential in k-space  

    """"Functions: Fermi-Dirac Distribution, Density matrix, Chemical Potential solver"""
    def Fermi_Dirac(self,E,Mu):#Electron Occupation at each k-pt for each band
        np.seterr(over='ignore',divide='ignore')
        nF=1/(1+np.exp(self.beta*(E-Mu)))#shape(nk,nk,N_band)
        return nF #output shape (nk,nk,N_band)
    
    def get_rho(self,vec,distribution):#E.shape(nk,nk,N_band),vec.shape=(nk,nk,N_band,N_band)
        output=np.einsum('ijm,ijkm,ijlm->ijkl',distribution,vec,np.conj(vec)) #summed over bands
        return output #output shape(nk,nk,N_band,N_band)
    
    def get_mu(self,E,ne): #Chemical potential solver
        Ne=ne*self.A
        def find_mu(Mu):
            return np.real(np.sum(self.Fermi_Dirac(E,Mu)))-(self.nk**2)*int(N_band/2)-Ne
        return optimize.bisect(find_mu, -200, 200)

#   Fock Matrix (using FFT) 
    def get_Fock(self,rho_k):
        np.seterr(divide='ignore')

        if self.V_type == "long":
            V_r=fftn(self.V_q)
            assert V_r.shape==(self.nk,self.nk)
            rho_r  = fftn(rho_k,axes=(0,1))
            Fock_r = -np.einsum('ijkl,ij->ijkl',rho_r,V_r)/self.A
            Fock_k = ifftshift(ifftn(Fock_r,axes=(0,1)),axes=(0,1)) #shape==(nk,nk,N_band,N_band)
            V_VI   = -(2*np.pi*self.ke*self.d_gate/self.er) * np.sum(rho_k,axis=(0,1)) / self.A
        else:
            Fock_k = -self.V0 * np.sum(rho_k,axis=(0,1)) / self.A
            V_VI   = Fock_k
            Fock_k = np.tile(Fock_k,[self.nk**2,1,1]).reshape(self.nk,self.nk,N_band,N_band) #in meV

        N_BL = int(N_band/4)

        idx_1 = np.ix_(np.r_[(0*N_BL):(1*N_BL),(2*N_BL):(3*N_BL)],\
                       np.r_[(0*N_BL):(1*N_BL),(2*N_BL):(3*N_BL)])
        idx_2 = np.ix_(np.r_[(1*N_BL):(2*N_BL),(3*N_BL):(4*N_BL)],\
                       np.r_[(1*N_BL):(2*N_BL),(3*N_BL):(4*N_BL)])
        Fock_k_VEx = np.zeros([N_band,N_band], dtype=complex)
        Fock_k_VEx[idx_1] = V_VI[idx_2]
        Fock_k_VEx[idx_2] = V_VI[idx_1]

        # Fock_Hund = np.zeros([N_band,N_band], dtype=complex)
        # Fock_Hund[(0*N_BL):(1*N_BL),(0*N_BL):(1*N_BL)] = V_VI[(1*N_BL):(2*N_BL),(1*N_BL):(2*N_BL)] - V_VI[(3*N_BL):(4*N_BL),(3*N_BL):(4*N_BL)]
        # Fock_Hund[(1*N_BL):(2*N_BL),(1*N_BL):(2*N_BL)] = V_VI[(0*N_BL):(1*N_BL),(0*N_BL):(1*N_BL)] - V_VI[(2*N_BL):(3*N_BL),(2*N_BL):(3*N_BL)]
        # Fock_Hund[(2*N_BL):(3*N_BL),(2*N_BL):(3*N_BL)] = V_VI[(3*N_BL):(4*N_BL),(3*N_BL):(4*N_BL)] - V_VI[(1*N_BL):(2*N_BL),(1*N_BL):(2*N_BL)]
        # Fock_Hund[(3*N_BL):(4*N_BL),(3*N_BL):(4*N_BL)] = V_VI[(2*N_BL):(3*N_BL),(2*N_BL):(3*N_BL)] - V_VI[(0*N_BL):(1*N_BL),(0*N_BL):(1*N_BL)]
        # # Sz(K) * Sz(K')
        # Fock_Hund[(0*N_BL):(1*N_BL),(2*N_BL):(3*N_BL)] = 2 * V_VI[(3*N_BL):(4*N_BL),(1*N_BL):(2*N_BL)]
        # Fock_Hund[(1*N_BL):(2*N_BL),(3*N_BL):(4*N_BL)] = 2 * V_VI[(2*N_BL):(3*N_BL),(0*N_BL):(1*N_BL)]
        # Fock_Hund[(2*N_BL):(3*N_BL),(0*N_BL):(1*N_BL)] = 2 * V_VI[(1*N_BL):(2*N_BL),(3*N_BL):(4*N_BL)]
        # Fock_Hund[(3*N_BL):(4*N_BL),(1*N_BL):(2*N_BL)] = 2 * V_VI[(0*N_BL):(1*N_BL),(2*N_BL):(3*N_BL)]
        # # Sx(K) * Sx(K') + Sy(K) * Sy(K')
        # Fock_k = Fock_k + self.alp * Fock_k_VEx + self.Hund * Fock_Hund

        Fock_k = Fock_k + self.alp * Fock_k_VEx
        
        match self.fix_sym:
            case "su4":
                tFock_k = (             Fock_k[:,:,int(0*N_band/4):int(1*N_band/4),int(0*N_band/4):int(1*N_band/4)] +\
                        np.conj(np.flip(Fock_k[:,:,int(1*N_band/4):int(2*N_band/4),int(1*N_band/4):int(2*N_band/4)],(0,1))) +\
                                        Fock_k[:,:,int(2*N_band/4):int(3*N_band/4),int(2*N_band/4):int(3*N_band/4)] +\
                        np.conj(np.flip(Fock_k[:,:,int(3*N_band/4):int(4*N_band/4),int(3*N_band/4):int(4*N_band/4)],(0,1))))/4
                
                Fock_k = np.zeros([self.nk,self.nk,N_band,N_band], dtype=complex)
                Fock_k[:,:,int(0*N_band/4):int(1*N_band/4),int(0*N_band/4):int(1*N_band/4)] = tFock_k
                Fock_k[:,:,int(1*N_band/4):int(2*N_band/4),int(1*N_band/4):int(2*N_band/4)] = np.conj(np.flip(tFock_k,(0,1)))
                Fock_k[:,:,int(2*N_band/4):int(3*N_band/4),int(2*N_band/4):int(3*N_band/4)] = tFock_k
                Fock_k[:,:,int(3*N_band/4):int(4*N_band/4),int(3*N_band/4):int(4*N_band/4)] = np.conj(np.flip(tFock_k,(0,1)))                        
                
            case "su2":
                tFock_k = (             Fock_k[:,:,int(0*N_band/4):int(1*N_band/4),int(0*N_band/4):int(1*N_band/4)] + \
                        np.conj(np.flip(Fock_k[:,:,int(3*N_band/4):int(4*N_band/4),int(3*N_band/4):int(4*N_band/4)],(0,1))))/2
                
                Fock_k[:,:,int(0*N_band/4):int(1*N_band/4),int(0*N_band/4):int(1*N_band/4)] = tFock_k
                Fock_k[:,:,int(3*N_band/4):int(4*N_band/4),int(3*N_band/4):int(4*N_band/4)] = np.conj(np.flip(tFock_k,(0,1)))

                tFock_k = (np.conj(np.flip(Fock_k[:,:,int(1*N_band/4):int(2*N_band/4),int(1*N_band/4):int(2*N_band/4)],(0,1))) + \
                                           Fock_k[:,:,int(2*N_band/4):int(3*N_band/4),int(2*N_band/4):int(3*N_band/4)])/2        
                
                Fock_k[:,:,int(1*N_band/4):int(2*N_band/4),int(1*N_band/4):int(2*N_band/4)] = np.conj(np.flip(tFock_k,(0,1)))
                Fock_k[:,:,int(2*N_band/4):int(3*N_band/4),int(2*N_band/4):int(3*N_band/4)] = tFock_k                

            case "su2_vc":
                tFock_k = (             Fock_k[:,:,int(0*N_band/4):int(1*N_band/4),int(0*N_band/4):int(1*N_band/4)] + \
                        np.conj(np.flip(Fock_k[:,:,int(1*N_band/4):int(2*N_band/4),int(1*N_band/4):int(2*N_band/4)],(0,1))))/2
                
                Fock_k[:,:,int(0*N_band/4):int(1*N_band/4),int(0*N_band/4):int(1*N_band/4)] = tFock_k
                Fock_k[:,:,int(1*N_band/4):int(2*N_band/4),int(1*N_band/4):int(2*N_band/4)] = np.conj(np.flip(tFock_k,(0,1)))

                tFock_k = (np.conj(np.flip(Fock_k[:,:,int(2*N_band/4):int(3*N_band/4),int(2*N_band/4):int(3*N_band/4)],(0,1))) + \
                                           Fock_k[:,:,int(3*N_band/4):int(4*N_band/4),int(3*N_band/4):int(4*N_band/4)])/2        
                
                Fock_k[:,:,int(2*N_band/4):int(3*N_band/4),int(2*N_band/4):int(3*N_band/4)] = np.conj(np.flip(tFock_k,(0,1)))
                Fock_k[:,:,int(3*N_band/4):int(4*N_band/4),int(3*N_band/4):int(4*N_band/4)] = tFock_k                                 

        return  Fock_k.reshape(self.nk**2,N_band,N_band) #in meV
    
    def try_mkdir(self,fchar):
        try:
            mkdir(fchar)
        except:
            pass    
    
    def Mean_field_State(self):
        # 0. Define initial values:

        kx_,ky_=self.kx.reshape(-1,1),self.ky.reshape(-1,1)
        H_nonint=self.nonint_H_string(kx_,ky_,self.SOC,self.SOC_dir) #shape(Nk,N_band,N_band)
        E_old,Eigenstates=self.Eigen(self.nk,kx_,ky_,self.SOC,self.SOC_dir,0) #Eigenstates.shape (nk,nk,N_band,N_band)
        seed_E,seed_Eigenstates=self.Eigen(self.nk,kx_,ky_,self.SOC,self.SOC_dir,1) #shape (nk,nk,N_band,N_band)
        mu_old=self.get_mu(E_old,self.ne_val[0])
        rho_old=self.get_rho(seed_Eigenstates,self.Fermi_Dirac(seed_E,mu_old))
        rho_mix=rho_old

        resultpath = './result/sweep'

        if self.H_renorm == True:
            resultpath = resultpath + '_with_r'

        if not(self.run_idx==0):
            resultpath = resultpath + f'_{self.run_idx}'
        if not(self.calc_type == ''):
            resultpath = resultpath + "_" + self.calc_type

        if (self.rmt_SWM==0):
            resultpath2 = resultpath + '/no_rmt_hop_'
        else:
            resultpath2 = resultpath + '/'

        if self.V_type == "long":
            resultpath2 = resultpath2 + f'er={self.er:.1f}'
        else:
            nu_0 = self.SWM["gamma_1"] / (3 * np.pi * self.SWM["gamma_0"]**2 * 2.46e-8**2); # [meV^-1 * cm^-2]
            resultpath2 = resultpath2 + f'V0xnu0={nu_0*self.V0:.4f}'

        resultpath2 = resultpath2 + f'_alp={self.alp:.3f}'
        if not(self.Hund == 0):
            resultpath2 = resultpath2 + f'_Hund={self.Hund:.2f}'

        resultpath3 = resultpath2+ f'/U={self.U:.1f}_SOC={self.SOC:.2f}_{self.seed_val["K_up"]:.1f}_{self.seed_val["K_dn"]:.1f}'+\
                    f'_{self.seed_val["Kp_up"]:.1f}_{self.seed_val["Kp_dn"]:.1f}'
                
        if not(self.seed_val["vx"] == 0 and self.seed_val["sx"] == 0 and self.seed_val["vsx"] == 0):
            resultpath3 = resultpath3 + f'_{self.seed_val["vx"]:.1f}_{self.seed_val["sx"]:.1f}_{self.seed_val["vsx"]:.1f}'

        if not(self.SOC_dir == 0):
            resultpath3 = resultpath3 + f'_SOC_dir={self.SOC_dir:.1f}'

        if not(self.fix_sym == ''):
            resultpath3 = resultpath3 + "_" + self.fix_sym

        # Make output directory first

        self.try_mkdir('result')
        self.try_mkdir(resultpath)
        self.try_mkdir(resultpath2)
        self.try_mkdir(resultpath3)
        self.try_mkdir(resultpath3+'/band_spin')
        self.try_mkdir(resultpath3+'/band_valley')
        self.try_mkdir(resultpath3+'/data')
        self.try_mkdir(resultpath3+'/fock_0')         

        for ne in self.ne_val:
            Ne=ne*self.A #total number of charge        
            # 2.Final state and Fermi Surface
            count=0
            error=3.14
            initial_time=time.time()
            lap_time = initial_time
            while error > self.tol and count < self.max_count:
                H_SCHF=H_nonint +self.get_Fock(rho_mix)#+self.get_valley_ex(rho_mix)#+self.Lagrange_multiplier()
                assert H_SCHF.shape==(self.nk**2,N_band,N_band)
                E_SCHF,eigenvectors=np.linalg.eigh(H_SCHF)
                E_SCHF=E_SCHF.reshape(self.nk,self.nk,N_band)
                eigenvectors=eigenvectors.reshape(self.nk,self.nk,N_band,N_band)
                mu_new=self.get_mu(E_SCHF,ne)

                match self.calc_type:
                    case "CBE":
                        rho_new =self.get_rho(eigenvectors,self.Fermi_Dirac(E_SCHF,mu_new))
                        rho_mix=self.mix*rho_new+(1-self.mix)*rho_old
                        E_nonint = np.sum(np.einsum('ijmn,ijnm->ij',\
                                H_nonint.reshape(self.nk,self.nk,N_band,N_band),\
                                self.get_rho(eigenvectors[:,:,:,int(N_band/2):N_band],abs(self.Fermi_Dirac(E_SCHF[:,:,int(N_band/2):N_band],mu_new))))).real/abs(Ne)
                        
                        E_total  = np.sum(np.einsum('ijmn,ijnm->ij',\
                                (H_nonint+0.5*self.get_Fock(rho_mix)).reshape(self.nk,self.nk,N_band,N_band),\
                                self.get_rho(eigenvectors[:,:,:,int(N_band/2):N_band],abs(self.Fermi_Dirac(E_SCHF[:,:,int(N_band/2):N_band],mu_new))))).real/abs(Ne)                                        

                    case "HF_CB_only":
                        rho_new =self.get_rho(eigenvectors[:,:,:,int(N_band/2):N_band],self.Fermi_Dirac(E_SCHF[:,:,int(N_band/2):N_band],mu_new))
                        rho_mix=self.mix*rho_new+(1-self.mix)*rho_old
                        E_nonint = np.sum(np.einsum('ijmn,ijnm->ij',\
                                H_nonint.reshape(self.nk,self.nk,N_band,N_band),\
                                self.get_rho(eigenvectors[:,:,:,int(N_band/2):N_band],abs(self.Fermi_Dirac(E_SCHF[:,:,int(N_band/2):N_band],mu_new))))).real/abs(Ne)
                        
                        E_total  = np.sum(np.einsum('ijmn,ijnm->ij',\
                                (H_nonint+0.5*self.get_Fock(rho_mix)).reshape(self.nk,self.nk,N_band,N_band),\
                                self.get_rho(eigenvectors[:,:,:,int(N_band/2):N_band],abs(self.Fermi_Dirac(E_SCHF[:,:,int(N_band/2):N_band],mu_new))))).real/abs(Ne)                                                                

                    case _:
                        rho_new =self.get_rho(eigenvectors,self.Fermi_Dirac(E_SCHF,mu_new))
                        rho_mix=self.mix*rho_new+(1-self.mix)*rho_old
                        E_nonint = np.sum(np.einsum('ijmn,ijnm->ij',\
                                H_nonint.reshape(self.nk,self.nk,N_band,N_band),\
                                self.get_rho(eigenvectors,abs(self.Fermi_Dirac(E_SCHF,mu_new))))).real/abs(Ne)
                        
                        E_total  = np.sum(np.einsum('ijmn,ijnm->ij',\
                                (H_nonint+0.5*self.get_Fock(rho_mix)).reshape(self.nk,self.nk,N_band,N_band),\
                                self.get_rho(eigenvectors,abs(self.Fermi_Dirac(E_SCHF,mu_new))))).real/abs(Ne)                                                                


                # 3. Calculate Errors for each step:
                error=np.sum(np.linalg.norm((rho_new-rho_old),axis=(-2, -1)))/(N_band*self.nk)**2 #error in density matrix
                count+=1
                rho_mix=rho_new*self.mix+(1-self.mix)*rho_old # computational steps
                rho_old=rho_new
                E_old=E_SCHF
                mu_old=mu_new
                if count==self.max_count-1 and error>self.tol:
                        print('Warning: The ground state didn\'t converge for ne={ne}e12, U={U}'.format(ne=ne/1e12,U=self.U))
                # print(count,':','mu=',mu_new,'energy per charge=',E_total,'\n','error_rho=',error,flush=True)
                print(f'(U={self.U:.1f}, ne={ne/1e12:.2f}e12) Iter #{count:2.0f}: Err: {error:.1e} (Tot: {time.time()-initial_time:.1f}s, Lap: {time.time()-lap_time:.1f}s)',flush=True)
                lap_time = time.time()

            print(f'(U={self.U:.1f}, ne={ne/1e12:.2f}e12) done. [Iter: {count}]',flush=True)            

            '''********OUTPUTS**********************'''
            '''Plot the bands along ky=0 and Fermi surface(s)'''
            
            
            valley_x  = np.sum(np.real(np.einsum('ijkm,mk->ij',rho_new,V_x)))/abs(Ne)
            valley_s_x= np.sum(np.real(np.einsum('ijkm,mk->ij',rho_new,V_x@S_x)))/abs(Ne)
            valley=np.real(np.einsum('ijkm,kl,ijlm->ijm',np.conj(eigenvectors),V_z,eigenvectors)) #valley polarization


            if (np.round(valley_x,5)!=0 or np.round(valley_s_x,5)!=0 or np.sum(valley > 0)!=self.nk**2*int(N_band/2) or np.sum(valley < 0)!=self.nk**2*int(N_band/2) ) : #non Ising
                fig, ax = plt.subplots(1, 3, figsize=(20, 4))
                spin=np.real(np.einsum('ijkm,kl,ijlm->ijm',np.conj(eigenvectors),S_z,eigenvectors)) #spin polarization
                col=spin
                for i in range(0,N_band):
                    s=ax[0].scatter(self.kx[0], E_SCHF[int(self.nk/2),:,i], c=col[int(self.nk/2),:,i], cmap='jet', marker='.', s=1, vmin=-1, vmax=1)
                cbar = fig.colorbar(s,ax=ax[0],ticks=[-1, 0, 1])
                cbar.set_label('$S_z$')    
                ax[0].axhline(y=mu_old, color='gray', linestyle='--')
                ax[0].set_ylim(mu_new-8,mu_new+8)
                ax[0].set_xlim(-np.pi/40, np.pi/40)
                ax[0].set_xticks([-np.pi/40, 0, np.pi/40])
                ax[0].set_xticklabels([r'-$\frac{\pi}{40}$', r'$0$', r'$\frac{\pi}{40}$'])
                ax[0].set_xlabel('$k_x$')
                ax[0].set_ylabel('E (meV)')
                ax[0].set_title('$n_e=${ne:.2f}$\\times 10^{{11}}$ cm$^{{-2}}$, $U=${U:.1f} meV'.format(ne=ne/1e11, U=self.U))
                
                

                fig.subplots_adjust(wspace=0.6)
                im = ax[1].pcolormesh(self.kx, self.ky,np.sum(self.Fermi_Dirac(E_SCHF,mu_new),axis=2)-int(N_band/2)*np.ones([self.nk,self.nk]),shading='auto',cmap='inferno',vmin=0,vmax=4)#*0.5*(col_v[:,:,:int(N_band/2)]+1) for K valley
                cbar = fig.colorbar(im,ax=ax[1]) # fix the colorbar to the second subplot
                ax[1].set_xlabel('$k_x$')
                ax[1].set_ylabel('$k_y$')
                ax[1].set_aspect('equal')

                ax[2].axis('off')

                resultpath4 = resultpath3 + f'/band_spin/ne={ne/1e12:.2f}e12.png'
                plt.savefig(resultpath4, transparent=True)
                plt.close()

                ####################
                
                fig, ax = plt.subplots(1, 3, figsize=(20, 4))
                col=valley
                for i in range(0,N_band):
                    s=ax[0].scatter(self.kx[0], E_SCHF[int(self.nk/2),:,i], c=col[int(self.nk/2),:,i], cmap='jet', marker='.', s=1, vmin=-1, vmax=1)
                cbar = fig.colorbar(s,ax=ax[0],ticks=[-1, 0, 1])
                cbar.set_label('$\\tau_z$')
                ax[0].axhline(y=mu_old, color='gray', linestyle='--')
                ax[0].set_ylim(mu_new-8,mu_new+8)
                ax[0].set_xlim(-np.pi/40, np.pi/40)
                ax[0].set_xticks([-np.pi/40, 0, np.pi/40])
                ax[0].set_xticklabels([r'-$\frac{\pi}{40}$', r'$0$', r'$\frac{\pi}{40}$'])
                ax[0].set_xlabel('$k_x$')
                ax[0].set_ylabel('E (meV)')
                ax[0].set_title('$n_e=${ne:.2f}$\\times 10^{{11}}$ cm$^{{-2}}$, $U=${U:.1f} meV'.format(ne=ne/1e11, U=self.U))

                fig.subplots_adjust(wspace=0.6)
                im = ax[1].pcolormesh(self.kx, self.ky,np.sum(self.Fermi_Dirac(E_SCHF,mu_new),axis=2)-int(N_band/2)*np.ones([self.nk,self.nk]),shading='auto',cmap='inferno',vmin=0,vmax=4)#*0.5*(col_v[:,:,:int(N_band/2)]+1) for K valley
                cbar = fig.colorbar(im,ax=ax[1]) # fix the colorbar to the second subplot
                ax[1].set_xlabel('$k_x$')
                ax[1].set_ylabel('$k_y$')
                ax[1].set_aspect('equal')

                ax[2].axis('off')
                
                resultpath4 = resultpath3 + f'/band_valley/ne={ne/1e12:.2f}e12.png'
                plt.savefig(resultpath4, transparent=True)
                plt.close()

            else: #Ising
                Eigen_values_K1=(E_SCHF[np.where(valley > 0)]).reshape(self.nk,self.nk,int(N_band/2))
                Eigen_values_K2=(E_SCHF[np.where(valley < 0)]).reshape(self.nk,self.nk,int(N_band/2))                
                
                fig, ax = plt.subplots(1, 3, figsize=(20, 4))
                spin=np.real(np.einsum('ijkm,kl,ijlm->ijm',np.conj(eigenvectors),S_z,eigenvectors)) #spin polarization
                col=spin
                for i in range(0,N_band):
                    s=ax[0].scatter(self.kx[0], E_SCHF[int(self.nk/2),:,i], c=col[int(self.nk/2),:,i], cmap='jet', marker='.', s=1, vmin=-1, vmax=1)
                cbar = fig.colorbar(s,ax=ax[0],ticks=[-1, 0, 1])
                cbar.set_label('$S_z$')
                ax[0].axhline(y=mu_old, color='gray', linestyle='--')
                ax[0].set_ylim(mu_new-8,mu_new+8)
                ax[0].set_xlim(-np.pi/40, np.pi/40)
                ax[0].set_xticks([-np.pi/40, 0, np.pi/40])
                ax[0].set_xticklabels([r'-$\frac{\pi}{40}$', r'$0$', r'$\frac{\pi}{40}$'])
                ax[0].set_xlabel('$k_x$')
                ax[0].set_ylabel('E (meV)')
                ax[0].set_title('$n_e=${ne:.2f}$\\times 10^{{11}}$ cm$^{{-2}}$, $U=${U:.1f} meV'.format(ne=ne/1e11, U=self.U))
                # Adjust the horizontal space between the subplots
                fig.subplots_adjust(wspace=0.6)
                im = ax[1].pcolormesh(self.kx, self.ky,np.sum(self.Fermi_Dirac(Eigen_values_K1,mu_new),axis=2)-int(N_band/4)*np.ones([self.nk,self.nk]),shading='auto',cmap='inferno',vmin=0,vmax=2)#*0.5*(col_v[:,:,:int(N_band/2)]+1) for K valley
                cbar = fig.colorbar(im,ax=ax[1]) # fix the colorbar to the second subplot
                ax[1].set_xlabel('$k_x$')
                ax[1].set_ylabel('$k_y$')
                ax[1].set_aspect('equal')
                ax[1].set_title('valley K')

                fig.subplots_adjust(wspace=0.6)
                im = ax[2].pcolormesh(self.kx, self.ky,np.sum(self.Fermi_Dirac(Eigen_values_K2,mu_new),axis=2)-int(N_band/4)*np.ones([self.nk,self.nk]),shading='auto',cmap='inferno',vmin=0,vmax=2)#*0.5*(col_v[:,:,:int(N_band/2)]+1) for K' valley
                cbar = fig.colorbar(im,ax=ax[2]) # fix the colorbar to the second subplot
                ax[2].set_xlabel('$k_x$')
                ax[2].set_ylabel('$k_y$')
                ax[2].set_aspect('equal')
                ax[2].set_title('valley K\'')

                resultpath4 = resultpath3 + f'/band_spin/ne={ne/1e12:.2f}e12.png'
                plt.savefig(resultpath4, transparent=True)
                plt.close()

                #####################

                fig, ax = plt.subplots(1, 3, figsize=(20, 4))
                col=valley
                for i in range(0,N_band):
                    s=ax[0].scatter(self.kx[0], E_SCHF[int(self.nk/2),:,i], c=col[int(self.nk/2),:,i], cmap='jet', marker='.', s=1, vmin=-1, vmax=1)
                cbar = fig.colorbar(s,ax=ax[0],ticks=[-1, 0, 1])
                cbar.set_label('$\\tau_z$')
                ax[0].axhline(y=mu_old, color='gray', linestyle='--')
                ax[0].set_ylim(mu_new-8,mu_new+8)
                ax[0].set_xlim(-np.pi/40, np.pi/40)
                ax[0].set_xticks([-np.pi/40, 0, np.pi/40])
                ax[0].set_xticklabels([r'-$\frac{\pi}{40}$', r'$0$', r'$\frac{\pi}{40}$'])
                ax[0].set_xlabel('$k_x$')
                ax[0].set_ylabel('E (meV)')
                ax[0].set_title('$n_e=${ne:.2f}$\\times 10^{{11}}$ cm$^{{-2}}$, $U=${U:.1f} meV'.format(ne=ne/1e11, U=self.U))
                # Adjust the horizontal space between the subplots
                fig.subplots_adjust(wspace=0.6)
                im = ax[1].pcolormesh(self.kx, self.ky,np.sum(self.Fermi_Dirac(Eigen_values_K1,mu_new),axis=2)-int(N_band/4)*np.ones([self.nk,self.nk]),shading='auto',cmap='inferno',vmin=0,vmax=2)#*0.5*(col_v[:,:,:int(N_band/2)]+1) for K valley
                cbar = fig.colorbar(im,ax=ax[1]) # fix the colorbar to the second subplot
                ax[1].set_xlabel('$k_x$')
                ax[1].set_ylabel('$k_y$')
                ax[1].set_aspect('equal')
                ax[1].set_title('valley K')

                fig.subplots_adjust(wspace=0.6)
                im = ax[2].pcolormesh(self.kx, self.ky,np.sum(self.Fermi_Dirac(Eigen_values_K2,mu_new),axis=2)-int(N_band/4)*np.ones([self.nk,self.nk]),shading='auto',cmap='inferno',vmin=0,vmax=2)#*0.5*(col_v[:,:,:int(N_band/2)]+1) for K' valley
                cbar = fig.colorbar(im,ax=ax[2]) # fix the colorbar to the second subplot
                ax[2].set_xlabel('$k_x$')
                ax[2].set_ylabel('$k_y$')
                ax[2].set_aspect('equal')
                ax[2].set_title('valley K\'')                

                resultpath4 = resultpath3 + f'/band_valley/ne={ne/1e12:.2f}e12.png'
                plt.savefig(resultpath4, transparent=True)
                plt.close()

            ii,jj=np.indices((self.nk,self.nk))
            ii,jj=np.reshape(ii,(-1,1)),np.reshape(jj,(-1,1))
            filled_band = np.max(np.apply_along_axis(lambda k: sum(E_SCHF[k[0]][k[1]]<mu_new), 0, (ii,jj))) - N_band/2

            resultpath4 = resultpath3 + f'/data/ne={ne/1e12:.2f}e12.dat'
            out_file = open(resultpath4,'w')
            out_file.write(f'{mu_new:.5f}\n{filled_band:.0f}\n{E_total:.5f}\n{E_nonint:.5f}\n')
            out_file.close()


            #shape(nk**2,N_band,N_band)
            H_Fock  = H_SCHF - H_nonint
            H_Fock  = H_Fock[int(self.nk**2/2)]
            H_SCHF0 = H_SCHF[int(self.nk**2/2)]
            
            # H_SCHF0 = H_Fock[25200]
            # H_Fock  = H_Fock[37800]
            
            resultpath4 = resultpath3 + f'/fock_0/ne={ne/1e12:.2f}e12.dat'
            out_file = open(resultpath4,'w')

            fock_diag = np.diag(np.real(H_Fock))-np.matrix.trace(np.real(H_Fock))/N_band
            formatted_array = [f"{number:.2f}" for number in fock_diag]
            out_file.write(' '.join(formatted_array)+'\n')
            out_file.write('\n')

            # Format and print each element to five decimal places
            formatted_array = [[f"{number:.2f}" for number in row] for row in np.real(H_SCHF0)]
            for row in formatted_array:
                out_file.write(' '.join(row)+'\n')

            out_file.write('\n')
            formatted_array = [[f"{number:.2f}" for number in row] for row in np.imag(H_SCHF0)]
            for row in formatted_array:
                out_file.write(' '.join(row)+'\n')        

            out_file.write('\n')
            formatted_array = [[f"{number:.2f}" for number in row] for row in np.real(H_Fock)]
            for row in formatted_array:
                out_file.write(' '.join(row)+'\n')

            out_file.write('\n')
            formatted_array = [[f"{number:.2f}" for number in row] for row in np.imag(H_Fock)]
            for row in formatted_array:
                out_file.write(' '.join(row)+'\n')        

            out_file.close()
        
        # return E_SCHF,eigenvectors,rho_new,mu_new,E_total